import torch.nn as nn
from transformers import BertModel
from model.encoder.personalized_bert import EventBertModel


class BertEncoder(nn.Module):
    """Wrap a pretrained BERT backbone and expose its pooled output.

    The backbone class is chosen from the ``model`` section of ``config``:
    ``EventBertModel`` when either ``use_event`` or ``use_event_type`` is
    true, otherwise ``transformers.BertModel``. Either way the weights are
    loaded with ``from_pretrained`` from the ``bert_path`` option.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Select and load the BERT backbone described by ``config``.

        Args:
            config: configparser-style object (presumably a
                ``configparser.ConfigParser`` -- confirm against caller) that
                provides ``getboolean``/``get`` for the ``model`` options
                ``use_event``, ``use_event_type`` and ``bert_path``.
            gpu_list: accepted for interface compatibility; not used here.
            *args: accepted for interface compatibility; not used here.
            **params: accepted for interface compatibility; not used here.

        Raises:
            Whatever ``config.getboolean`` / ``config.get`` raise for a
            missing or malformed option, and whatever ``from_pretrained``
            raises when the weights cannot be loaded.
        """
        super().__init__()

        # Read every option exactly once and before any output. The branch
        # decision and the printed flag values then come from the same read,
        # and a bad config fails before a partial banner is printed.
        use_event = config.getboolean('model', 'use_event')
        use_event_type = config.getboolean('model', 'use_event_type')
        bert_path = config.get('model', 'bert_path')

        if use_event or use_event_type:
            print('\nusing EDBERT (Event Detection BERT)')
            print('\tuse event: ', use_event)
            print('\tuse event type: ', use_event_type)
            model_cls = EventBertModel
        else:
            print('\nusing original BERT\n')
            model_cls = BertModel

        self.bert = model_cls.from_pretrained(bert_path)

    def forward(self, x):
        """Run the backbone and return its ``pooler_output``.

        Args:
            x: mapping of keyword arguments forwarded verbatim to the
                backbone (e.g. ``input_ids``, ``attention_mask`` -- assumed
                from the ``**x`` call; confirm against caller).

        Returns:
            The ``'pooler_output'`` entry of the backbone's output (presumably
            a ``(batch, hidden)`` tensor -- TODO confirm).
        """
        outputs = self.bert(**x)
        return outputs['pooler_output']
